/* -*-C++-*- */
/* This file contains the IR classes for all expressions.
   The base classes are in base.def */

/** \addtogroup irdefs
  * @{
  */

/// Base for operations that take a single operand, `expr`.
/// The precedence defaults to DBPrint::Prec_Prefix.
abstract Operation_Unary : Operation {
    Expression expr;
    Operation_Unary {
        // Borrow the operand's source position when none was supplied.
        if (!srcInfo) {
            if (expr) srcInfo = expr->srcInfo;
        }
        // Borrow the operand's type while this node's type is still unknown.
        if (type->is<Type::Unknown>()) {
            if (expr) type = expr->type;
        }
    }
    precedence = DBPrint::Prec_Prefix;
}

/// Unary arithmetic negation; operator string "-".
class Neg : Operation_Unary {
    stringOp = "-";
}

/// Unary bitwise complement; operator string "~".
class Cmpl : Operation_Unary {
    stringOp = "~";
}

/// Unary logical negation; operator string "!".
class LNot : Operation_Unary {
    stringOp = "!";
}

/// Base for operations that take two operands, `left` and `right`.
abstract Operation_Binary : Operation {
    Expression left;
    Expression right;
    Operation_Binary {
        // Without explicit source info, span both operands.
        if (!srcInfo) {
            if (left && right) srcInfo = left->srcInfo + right->srcInfo;
        }
        // An unknown result type is inferred only when both operands agree on one.
        if (type->is<Type::Unknown>()) {
            if (left && right) {
                if (left->type == right->type) type = left->type;
            }
        }
    }
}

/// Base for operations that take three operands: `e0`, `e1` and `e2`.
abstract Operation_Ternary : Operation {
    Expression e0;
    Expression e1;
    Expression e2;
    Operation_Ternary {
        // Without explicit source info, span from the first to the last operand.
        if (!srcInfo) {
            if (e0 && e2) srcInfo = e0->srcInfo + e2->srcInfo;
        }
    }
}

/// Base for binary relational operations: the constructor body sets the
/// node's type to Type::Boolean unconditionally.
abstract Operation_Relation : Operation_Binary {
    Operation_Relation { type = Type::Boolean::get(); }
}

/// Binary multiplication; operator string "*".
class Mul : Operation_Binary {
    stringOp = "*";
    precedence = DBPrint::Prec_Mul;
}

/// Binary division; operator string "/".
class Div : Operation_Binary {
    stringOp = "/";
    precedence = DBPrint::Prec_Div;
}
/// Binary modulo; operator string "%".
class Mod : Operation_Binary {
    stringOp = "%";
    precedence = DBPrint::Prec_Mod;
}
/// Binary addition; operator string "+".
class Add : Operation_Binary {
    stringOp = "+";
    precedence = DBPrint::Prec_Add;
}
/// Binary subtraction; operator string "-".
class Sub : Operation_Binary {
    stringOp = "-";
    precedence = DBPrint::Prec_Sub;
}
/// Saturating addition; operator string "|+|".
class AddSat : Operation_Binary {
    stringOp = "|+|";
    precedence = DBPrint::Prec_AddSat;
}
/// Saturating subtraction; operator string "|-|".
class SubSat : Operation_Binary {
    stringOp = "|-|";
    precedence = DBPrint::Prec_SubSat;
}
/// Left shift; operator string "<<".
class Shl : Operation_Binary {
    stringOp = "<<";
    precedence = DBPrint::Prec_Shl;
    Shl {
        // A still-unknown result type is taken from the left operand alone;
        // the right operand's type is not consulted here.
        if (type->is<Type::Unknown>()) {
            if (left) type = left->type;
        }
    }
}
/// Right shift; operator string ">>".
class Shr : Operation_Binary {
    stringOp = ">>";
    precedence = DBPrint::Prec_Shr;
    Shr {
        // A still-unknown result type is taken from the left operand alone;
        // the right operand's type is not consulted here.
        if (type->is<Type::Unknown>()) {
            if (left) type = left->type;
        }
    }
}
/// Equality comparison; operator string "==".
class Equ : Operation_Relation {
    stringOp = "==";
    precedence = DBPrint::Prec_Equ;
}
/// Inequality comparison; operator string "!=".
class Neq : Operation_Relation {
    stringOp = "!=";
    precedence = DBPrint::Prec_Neq;
}
/// Less-than comparison; operator string "<".
class Lss : Operation_Relation {
    stringOp = "<";
    precedence = DBPrint::Prec_Lss;
}
/// Less-than-or-equal comparison; operator string "<=".
class Leq : Operation_Relation {
    stringOp = "<=";
    precedence = DBPrint::Prec_Leq;
}
/// Greater-than comparison; operator string ">".
class Grt : Operation_Relation {
    stringOp = ">";
    precedence = DBPrint::Prec_Grt;
}
/// Greater-than-or-equal comparison; operator string ">=".
class Geq : Operation_Relation {
    stringOp = ">=";
    precedence = DBPrint::Prec_Geq;
}
/// Bitwise AND; operator string "&".
class BAnd : Operation_Binary {
    stringOp = "&";
    precedence = DBPrint::Prec_BAnd;
}
/// Bitwise OR; operator string "|".
class BOr : Operation_Binary {
    stringOp = "|";
    precedence = DBPrint::Prec_BOr;
}
/// Bitwise XOR; operator string "^".
class BXor : Operation_Binary {
    stringOp = "^";
    precedence = DBPrint::Prec_BXor;
}
/// Logical AND; operator string "&&".
/// Derives from Operation_Binary, not Operation_Relation, so its type is not
/// forced to Boolean by a constructor body in this file.
class LAnd : Operation_Binary {
    stringOp = "&&";
    precedence = DBPrint::Prec_LAnd;
}
/// Logical OR; operator string "||".
/// Derives from Operation_Binary, not Operation_Relation, so its type is not
/// forced to Boolean by a constructor body in this file.
class LOr : Operation_Binary {
    stringOp = "||";
    precedence = DBPrint::Prec_LOr;
}

/// Abstract base of the literal expressions in this file (Constant,
/// BoolLiteral, StringLiteral); every literal is also a CompileTimeValue.
abstract Literal : Expression, CompileTimeValue {
#nodbprint
}

/// An integer literal with arbitrary precision.
class Constant : Literal {
    big_int value;
    optional unsigned  base;  /// base used when reading/writing
#noconstructor
    /// if noWarning is true, no warning is emitted
    void handleOverflow(bool noWarning);
    // We need to enumerate all the integer types because we need proper 64-bit handling on
    // 32-bit systems (where long is not 64 bits wide), and mpz_import is too big a hammer
    // because it loses the signedness of the value.
    // Constructors that take no Type argument give the literal the type Type_InfInt.
    Constant(int v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
    Constant(unsigned v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
#emit
#if __WORDSIZE == 64
    Constant(intmax_t v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
#else
    Constant(long v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
    Constant(unsigned long v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
    Constant(intmax_t v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
#endif
#end
    Constant(uint64_t v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
    Constant(big_int v, unsigned base = 10) :
        Literal(new Type_InfInt()), value(v), base(base) {}
    Constant(Util::SourceInfo si, big_int v, unsigned base = 10) :
        Literal(si, new Type_InfInt()), value(v), base(base) {}
    // Constructors that take an explicit type require it to be non-null and
    // call handleOverflow() on the stored value.
    Constant(const Type *t, big_int v, unsigned base = 10, bool noWarning = false) :
        Literal(t), value(v), base(base) { CHECK_NULL(t); handleOverflow(noWarning); }
    Constant(Util::SourceInfo si, const Type *t, big_int v,
             unsigned base = 10, bool noWarning = false) :
        Literal(si, t), value(v), base(base) { CHECK_NULL(t); handleOverflow(noWarning); }
#emit
    static Constant GetMask(unsigned width);
#end
    // Range predicates: true when `value` lies within the limits of the named C++ type.
    bool fitsInt() const { return value >= INT_MIN && value <= INT_MAX; }
    bool fitsLong() const { return value >= LONG_MIN && value <= LONG_MAX; }
    bool fitsUint() const { return value >= 0 && value <= UINT_MAX; }
    bool fitsUint64() const { return value >= 0 && value <= UINT64_MAX; }
    bool fitsInt64() const { return value >= INT64_MIN && value <= INT64_MAX; }
    // Conversions: each reports ERR_OVERLIMIT through ::error when the value does not
    // fit the target type (this covers values below the minimum as well as above the
    // maximum), and then still returns the static_cast of `value`.
    long asLong() const {
        if (!fitsLong())
            ::error(ErrorType::ERR_OVERLIMIT, "%1$x: Value too large for long", this);
        return static_cast<long>(value); }
    int asInt() const {
        if (!fitsInt())
            ::error(ErrorType::ERR_OVERLIMIT, "%1$x: Value too large for int", this);
        return static_cast<int>(value); }
    unsigned asUnsigned() const {
        if (!fitsUint())
            ::error(ErrorType::ERR_OVERLIMIT, "%1$x: Value too large for unsigned int", this);
        return static_cast<unsigned>(value);
    }
    uint64_t asUint64() const {
        if (!fitsUint64())
            ::error(ErrorType::ERR_OVERLIMIT, "%1$x: Value too large for uint64", this);
        return static_cast<uint64_t>(value);
    }
    int64_t asInt64() const {
        if (!fitsInt64())
            ::error(ErrorType::ERR_OVERLIMIT, "%1$x: Value too large for int64", this);
        return static_cast<int64_t>(value);
    }
    // The following operators are only used in special circumstances.
    // They do not look at the type when operating.  There are separate
    // implementations of these computations when doing proper constant folding.
#emit
    Constant operator<<(const unsigned &shift) const;
    Constant operator>>(const unsigned &shift) const;
    Constant operator&(const Constant &c) const;
    Constant operator|(const Constant &c) const;
    Constant operator^(const Constant &c) const;
    Constant operator-(const Constant &c) const;
    Constant operator-() const;
#end
    // Formats the value using the width and signedness of a Type_Bits type;
    // for any other type it passes width 0 and unsigned.
    toString {
        const IR::Type_Bits* tb = dynamic_cast<const IR::Type_Bits*>(type);
        unsigned width;
        bool sign;
        if (tb == nullptr) {
            width = 0;
            sign = false;
        } else {
            width = tb->size;
            sign = tb->isSigned;
        }
        return Util::toString(value, width, sign, base);
    }
    // Only the type is offered to visitors.
    visit_children { v.visit(type, "type"); }
}

/// A boolean literal.
class BoolLiteral : Literal {
    bool value;
    // Spelled out as the keyword matching the stored value.
    toString {
        if (value) return "true";
        return "false";
    }
}

/// A string literal; validation rejects a null `value`.
class StringLiteral : Literal {
    cstring value;
    validate{ if (value.isNull()) BUG("null StringLiteral"); }
    // Renders the value wrapped in double quotes; no escaping is applied here.
    toString{ return cstring("\"") + value + "\""; }
    // Builds a literal from an ID, taking over its source info and name.
    StringLiteral(ID v) : Literal(v.srcInfo), value(v.name) {}
#emit
    operator IR::ID() const { return IR::ID(srcInfo, value); }
#end
}

/// An expression consisting of a path.
class PathExpression : Expression {
    Path path;
    PathExpression {
        // Fall back to the path's source position when none was supplied.
        if (!srcInfo) {
            if (path) srcInfo = path->srcInfo;
        }
    }
    // Convenience constructor: wraps the identifier in a fresh IR::Path.
    PathExpression(IR::ID id)
        : Expression(id.srcInfo), path(new IR::Path(id)) {}
    toString { return path->toString(); }
}

// Given
//     enum X { a }
// the 'X' portion of the expression 'X.a' is a TypeNameExpression.
class TypeNameExpression : Expression {
    Type_Name typeName;
    TypeNameExpression {
        // Fall back to the type name's source position when none was supplied.
        if (!srcInfo) {
            if (typeName) srcInfo = typeName->srcInfo;
        }
    }
    // Convenience constructor: builds the Type_Name from a path made of the identifier.
    TypeNameExpression(ID id)
        : Expression(id.srcInfo), typeName(new IR::Type_Name(new IR::Path(id))) {}
    dbprint { out << typeName; }
    toString { return typeName->toString(); }
}

/// Bit slice `e0[e1:e2]`, where e1 is the high index and e2 the low index.
class Slice : Operation_Ternary {
    precedence = DBPrint::Prec_Postfix;
    stringOp = "[:]";
    toString { return e0->toString() + "[" + e1->toString() + ":" + e2->toString() + "]"; }
    // Both accessors assume the bounds are IR::Constant nodes, which holds after
    // type checking.
    unsigned getH() const { return e1->to<IR::Constant>()->asUnsigned(); }
    unsigned getL() const { return e2->to<IR::Constant>()->asUnsigned(); }
    // Constant-bound constructors: the result is a Bits type of width hi-lo+1.
    Slice(Expression a, int hi, int lo)
        : Operation_Ternary(IR::Type::Bits::get(hi - lo + 1), a,
                            new Constant(hi), new Constant(lo)) {}
    Slice(Util::SourceInfo si, Expression a, int hi, int lo)
        : Operation_Ternary(si, IR::Type::Bits::get(hi - lo + 1), a,
                            new Constant(hi), new Constant(lo)) {}
    Slice {
        // With an unknown type and two constant bounds, derive the Bits width
        // from the bounds.
        if (type->is<Type::Unknown>()) {
            if (e1 && e1->is<Constant>()) {
                if (e2 && e2->is<Constant>())
                    type = IR::Type::Bits::get(getH() - getL() + 1);
            }
        }
    }
    // make a slice, folding slices on slices and slices on constants automatically
    static Expression make(Expression a, unsigned hi, unsigned lo);
}

/// Member access `expr.member`.
class Member : Operation_Unary {
    precedence = DBPrint::Prec_Postfix;
    ID member;
    // Declared here only; the definitions are not part of this file.
    virtual int offset_bits() const;
    int lsb() const;
    int msb() const;
    stringOp = ".";
    toString{ return expr->toString() + "." + member; }
}

/// Concatenation; operator string "++".
class Concat : Operation_Binary {
    stringOp = "++";
    precedence = DBPrint::Prec_Add;
    Concat {
        if (left && right) {
            const auto *leftBits = left->type->to<IR::Type::Bits>();
            const auto *rightBits = right->type->to<IR::Type::Bits>();
            // When both operands are Bits, the result width is the sum of the two
            // widths and the signedness comes from the left operand.
            if (leftBits && rightBits) {
                type = IR::Type::Bits::get(leftBits->size + rightBits->size,
                                           leftBits->isSigned);
            }
        }
    }
}

/// Indexing `left[right]`.
class ArrayIndex : Operation_Binary {
    stringOp = "[]";
    precedence = DBPrint::Prec_Postfix;
    ArrayIndex {
        // If the indexed expression has a Type_Stack type, the result is its
        // element type.
        if (left) {
            if (auto stackType = left->type->to<IR::Type_Stack>())
                type = stackType->elementType;
        }
    }
    toString { return left->toString() + "[" + right->toString() + "]"; }
}

/// Range `left .. right`.
class Range : Operation_Binary {
    stringOp = "..";
    precedence = DBPrint::Prec_Low;
    Range {
        // If this node's type is the same as the left operand's type and that type
        // is known, replace it with a Type_Set over it.
        if (left) {
            if (type == left->type && !left->type->is<Type::Unknown>())
                type = new Type_Set(left->type);
        }
    }
}

/// Mask `left &&& right`.
class Mask : Operation_Binary {
    stringOp = "&&&";
    precedence = DBPrint::Prec_Low;
    Mask {
        // If this node's type is the same as the left operand's type and that type
        // is known, replace it with a Type_Set over it.
        if (left) {
            if (type == left->type && !left->type->is<Type::Unknown>())
                type = new Type_Set(left->type);
        }
    }
}

/// Conditional expression `e0 ? e1 : e2`.
class Mux : Operation_Ternary {
    stringOp = "?:";
    precedence = DBPrint::Prec_Low;
    // Visit order matters: the condition e0 is visited first, then the visitor is
    // flow-cloned; e1 is visited with the original visitor and e2 with the clone,
    // and the clone is merged back afterwards.
    visit_children {
        v.visit(e0, "e0");
        auto &clone(v.flow_clone());
        v.visit(e1, "e1");
        clone.visit(e2, "e2");
        v.flow_merge(clone); }
    // An unknown type is inferred only when both branches have the same type.
    Mux { if (type->is<Type::Unknown>() && e1 && e2 && e1->type == e2->type) type = e1->type; }
}

/// An expression node with no fields of its own.
/// NOTE(review): presumably represents the `default` keyset/label — confirm against its users.
class DefaultExpression : Expression {}

// Two different This nodes should not compare equal.
// That's why we use a hidden id field to distinguish them: each node takes the
// next value of the static counter nextId.
class This : Expression {
    int id = nextId++;
 private:
    static int nextId;
}  // experimental

/// Cast of `expr` to `destType`.
class Cast : Operation_Unary {
    /// Defaults to this node's `type`.
    Type destType = type;
    /// 'destType' and 'type' will generally always be the same, except when a cast to a
    /// type argument of a generic occurs.  Then at some point, the 'destType' will be
    /// specialized to a concrete type, and 'type' will only be updated later when type
    /// inferencing occurs
    precedence = DBPrint::Prec_Prefix;
    stringOp = "(cast)";
    toString{ return "cast"; }
    // The destination type must never be Type_Unknown.
    validate{ BUG_CHECK(!destType->is<Type_Unknown>(), "%1%: Cannot cast to unknown type", this); }
}

/// One case of a select expression: a keyset expression paired with a `state`
/// path expression.  Debug-printed as "keyset: state".
class SelectCase {
    Expression     keyset;
    PathExpression state;
    dbprint { out << keyset << ": " << state; }
}

/// A select expression: the `select` list expression plus its cases.
class SelectExpression : Expression {
    ListExpression            select;
    inline Vector<SelectCase> selectCases;
    // The select list is visited first; the cases are then visited with
    // parallel_visit rather than a plain visit.
    visit_children {
        v.visit(select, "select");
        v.parallel_visit(selectCases, "selectCases"); }
}

/// A call of `method` with optional type arguments and arguments.
/// Both vectors default to new empty vectors.
class MethodCallExpression : Expression {
    Expression                  method;
    optional Vector<Type>       typeArguments = new Vector<Type>;
    optional Vector<Argument>   arguments = new Vector<Argument>;
    toString { return method->toString(); }
    // Neither vector may contain a null element.
    validate {
        typeArguments->check_null();
        arguments->check_null();
    }
    // Call of a method named by an identifier; the name is wrapped in a PathExpression.
    MethodCallExpression(Util::SourceInfo si, IR::ID m, const std::initializer_list<Argument> &a)
        : Expression(si), method(new PathExpression(m)), arguments(new Vector<Argument>(a)) {}
    // Call of an arbitrary method expression with ready-made arguments.
    MethodCallExpression(Util::SourceInfo si, Expression m,
                         const std::initializer_list<Argument> &a)
        : Expression(si), method(m), arguments(new Vector<Argument>(a)) {}
    // Call with plain expressions: each one is wrapped in a new Argument.
    MethodCallExpression(Expression m, const std::initializer_list<const Expression *> &a)
        : method(m), arguments(nullptr) {
        auto wrapped = new Vector<Argument>;
        for (auto arg : a)
            wrapped->push_back(new Argument(arg));
        arguments = wrapped;
    }
}

/// A constructor invocation of `constructedType` with `arguments`.
class ConstructorCallExpression : Expression {
    /// Defaults to this node's `type`.
    Type               constructedType = type;  // Either a Type_Name or a Specialized_Type
    Vector<Argument>   arguments;
    toString{ return constructedType->toString(); }
    // The constructed type must be a Type_Name or a Type_Specialized, and no
    // argument may be null.
    validate{ BUG_CHECK(constructedType->is<Type_Name>() ||
                        constructedType->is<Type_Specialized>(),
                        "%1%: unexpected type", constructedType);
        arguments->check_null(); }
}

/// A comma-separated list of expressions.
class ListExpression : Expression {
    inline Vector<Expression> components;
    ListExpression {
        // Validate first so that the loop below never dereferences a null component.
        validate();
        if (type->is<Type::Unknown>()) {
            // Build a Type_List from the component types, in order.
            Vector<Type> componentTypes;
            for (auto component : components)
                componentTypes.push_back(component->type);
            type = new Type_List(componentTypes);
        }
    }
    validate { components.check_null(); }
    size_t size() const { return components.size(); }
    void push_back(Expression e) { components.push_back(e); }
}

/// An expression that evaluates to a struct.
class StructExpression : Expression {
    /// The struct or header type that is being initialized.
    /// May only be known after type checking; so it can be nullptr.
    NullOK Type structType;
    inline IndexedVector<NamedExpression> components;
    // Components must be non-null and pass their own validation; structType, when
    // present, must be a Type_Name or a Type_Specialized.
    validate {
        components.check_null(); components.validate();
        BUG_CHECK(structType == nullptr || structType->is<IR::Type_Name>() ||
                  structType->is<IR::Type_Specialized>(),
                  "%1%: unexpected struct type", this);
    }
    size_t size() const { return components.size(); }
}

/// A ListExpression where all the components are compile-time values.
/// This is used by the evaluator pass.
class ListCompileTimeValue : CompileTimeValue {
    inline Vector<Node> components;
    // Every component must itself be a CompileTimeValue.
    validate {
        for (auto v : components)
            BUG_CHECK(v->is<CompileTimeValue>(), "%1%: not a compile-time value", v); }
#nodbprint
}

/// Experimental: an extern method/function call with constant arguments to be
/// evaluated at compile time
class CompileTimeMethodCall : MethodCallExpression, CompileTimeValue {
    // Built by copying an existing MethodCallExpression.
    CompileTimeMethodCall(MethodCallExpression e) : MethodCallExpression(*e) {}
    // NOTE(review): this tests each Argument node itself for CompileTimeValue.
    // Arguments are wrappers built from expressions (see MethodCallExpression), so
    // confirm whether the check should inspect the wrapped expression instead.
    validate {
        for (auto v : *arguments)
            BUG_CHECK(v->is<CompileTimeValue>(), "%1%: not a compile-time value", v); }
#nodbprint
}

/** @} *//* end group irdefs */
